{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vfch6fQlnG05"
   },
   "source": [
    "# Guide to grad samplers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKnAlq0AnQCd"
   },
   "source": [
    "DP-SGD guarantees the privacy of every sample used in training. To achieve this, we have to bound the sensitivity of every sample, which in turn means clipping the gradient of every sample. Unfortunately, PyTorch doesn't keep the gradients of individual samples in a batch; it only exposes the gradient aggregated over all the samples in a batch via the [`.grad` attribute](https://pytorch.org/docs/stable/generated/torch.Tensor.grad.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ERr4O-dVXMne"
   },
   "source": [
    "The easiest way to get what we want is to process the samples of each batch one at a time (effectively training with a batch size of 1), as follows:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TDj9uJZd89O7"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Sketch only: `model`, `criterion`, `train_dataset`, `accumulate_and_noise`\n",
    "# and `dp_parameters` are assumed to be defined elsewhere.\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "for x, y in DataLoader(train_dataset, batch_size=128):\n",
    "  for p in model.parameters():\n",
    "    p.accumulated_grads = []\n",
    "\n",
    "  # Run samples one-by-one to get per-sample gradients\n",
    "  for x_i, y_i in zip(x, y):\n",
    "    y_hat_i = model(x_i.unsqueeze(0))  # restore the batch dimension\n",
    "    loss = criterion(y_hat_i, y_i.unsqueeze(0))\n",
    "    loss.backward()\n",
    "\n",
    "    # Clip this sample's gradient in place (norm taken over all parameters)\n",
    "    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    for p in model.parameters():\n",
    "      p.accumulated_grads.append(p.grad.detach().clone())\n",
    "    model.zero_grad()  # p.grad is accumulative, so we need to manually reset\n",
    "\n",
    "  # Aggregate clipped gradients of all samples in a batch, and add DP noise\n",
    "  for p in model.parameters():\n",
    "    p.grad = accumulate_and_noise(p.accumulated_grads, dp_parameters)\n",
    "\n",
    "  optimizer.step()\n",
    "  optimizer.zero_grad()\n"
   ]
  },
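  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above leaves `accumulate_and_noise` undefined. As a rough illustration only, the cell below sketches one way such a helper could look: it sums the clipped per-sample gradients, adds Gaussian noise, and averages over the batch. The dictionary layout of `dp_parameters` and its keys `noise_multiplier` and `max_grad_norm` are assumptions made for this sketch, not part of Opacus. The returned tensor has the same shape as the parameter, so it can be assigned to `p.grad` in the per-sample loop above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def accumulate_and_noise(accumulated_grads, dp_parameters):\n",
    "    \"\"\"Sum clipped per-sample gradients, add Gaussian noise, and average.\n",
    "\n",
    "    `dp_parameters` is a hypothetical dict with the keys\n",
    "    \"noise_multiplier\" and \"max_grad_norm\" (assumed for this sketch).\n",
    "    \"\"\"\n",
    "    stacked = torch.stack(accumulated_grads)  # [batch, *param.shape]\n",
    "    summed = stacked.sum(dim=0)\n",
    "    # Noise scale is proportional to the clipping threshold\n",
    "    std = dp_parameters[\"noise_multiplier\"] * dp_parameters[\"max_grad_norm\"]\n",
    "    noise = torch.randn_like(summed) * std\n",
    "    return (summed + noise) / len(accumulated_grads)\n",
    "\n",
    "\n",
    "dp_parameters = {\"noise_multiplier\": 1.0, \"max_grad_norm\": 1.0}\n",
    "grads = [torch.ones(3, 2) for _ in range(4)]  # four fake per-sample gradients\n",
    "noisy_mean = accumulate_and_noise(grads, dp_parameters)\n",
    "print(noisy_mean.shape)\n"
   ]
  },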
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dAULBrBeXRB1"
   },
   "source": [
    "Processing samples one at a time, however, would be a serious waste of time and resources, since it forgoes all the vectorized optimizations that batching provides.\n",
    "\n",
    "`GradSampleModule` is an `nn.Module` replacement offered by Opacus to solve the above problem. In addition to the `.grad` attribute, the parameters of this module also have a `.grad_sample` attribute that holds the per-sample gradients.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DW4DumLD0DWW"
   },
   "source": [
    "## `GradSampleModule` internals\n",
    "For most modules, Opacus provides a function (a \"grad_sampler\") that computes the per-sample gradients of a batch by, more or less, doing the backpropagation \"by hand\".\n",
    "\n",
    "`GradSampleModule` is a wrapper around the existing `nn.Module`s. It attaches the above function to the modules it wraps using [backward hooks](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?#torch.nn.Module.register_backward_hook). It also provides other auxiliary methods such as validation, utilities to add/remove/set/reset `grad_sample`, utilities to `attach/remove` hooks, etc.\n",
    "\n",
    "TL;DR: grad_samplers contain the logic to compute the gradients given the activations and backpropagated gradients, and the `GradSampleModule` takes care of everything else by attaching the grad_samplers to the right modules and exposes a simple/minimal interface to the user.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PvsAyct8cL-I"
   },
   "source": [
    "Let's see an example. Say you want to get a GradSampleModule version of `nn.Linear`. This is what you would have to do:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yKv1fKKRckup",
    "outputId": "8b4714e9-1cde-4345-92fa-7c8d90928abb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before wrapping: Linear(in_features=42, out_features=2, bias=True)\n",
      "After wrapping : GradSample(Linear(in_features=42, out_features=2, bias=True))\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "from opacus.grad_sample import GradSampleModule\n",
    "\n",
    "lin_mod = nn.Linear(42,2)\n",
    "print(f\"Before wrapping: {lin_mod}\")\n",
    "\n",
    "gs_lin_mod = GradSampleModule(lin_mod)\n",
    "print(f\"After wrapping : {gs_lin_mod}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dHCBtWUYLdOz"
   },
   "source": [
    "That's it!\n",
    "`GradSampleModule` wraps your linear module with all the goodies and you can use this module as a drop-in replacement."
   ]
  },
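  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the wrapper adds, the sketch below runs one forward and backward pass through a wrapped `nn.Linear` and prints the shapes of `.grad` and `.grad_sample` for each parameter. It assumes a recent Opacus release with default settings; the batch size of 8 and the `.sum()` loss are arbitrary choices for illustration. Each `.grad_sample` should carry a leading batch dimension in front of the parameter's own shape.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from opacus.grad_sample import GradSampleModule\n",
    "\n",
    "model = GradSampleModule(nn.Linear(42, 2))\n",
    "x = torch.randn(8, 42)  # a batch of 8 samples\n",
    "model(x).sum().backward()\n",
    "\n",
    "for name, p in model.named_parameters():\n",
    "    # .grad keeps the parameter's shape; .grad_sample has one entry per sample\n",
    "    print(name, tuple(p.grad.shape), tuple(p.grad_sample.shape))\n"
   ]
  },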
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HlXihSpLf5gL"
   },
   "source": [
    "### grad_sampler internals\n",
    "Now, what does the grad_sampler for the above `nn.Linear` layer look like? It looks as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fgnf6V6Fm730"
   },
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def compute_linear_grad_sample(\n",
    "    layer: nn.Linear, activations: torch.Tensor, backprops: torch.Tensor\n",
    ") -> Dict[nn.Parameter, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Computes per sample gradients for ``nn.Linear`` layer\n",
    "    Args:\n",
    "        layer: Layer\n",
    "        activations: Activations\n",
    "        backprops: Backpropagations\n",
    "    \"\"\"\n",
    "    gs = torch.einsum(\"n...i,n...j->nij\", backprops, activations)\n",
    "    ret = {layer.weight: gs}\n",
    "    if layer.bias is not None:\n",
    "        ret[layer.bias] = torch.einsum(\"n...k->nk\", backprops)\n",
    "\n",
    "    return ret"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y80AplsWgI5M"
   },
   "source": [
    "The above grad_sampler takes in the activations and backpropagated gradients, computes the per-sample-gradients with respect to the module parameters, and maps them to the corresponding parameters.\n",
    "This [blog](https://medium.com/pytorch/differential-privacy-series-part-2-efficient-per-sample-gradient-computation-in-opacus-5bf4031d9e22) discusses the implementation and the math behind it in detail."
   ]
  },
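  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the einsum expressions, the sketch below compares them against per-sample gradients obtained the slow way, with one backward pass per sample. It uses plain PyTorch only; the layer sizes and the toy loss `sum(y ** 2)` are arbitrary choices for illustration, and the backpropagated gradients are obtained with `torch.autograd.grad` rather than through Opacus hooks.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "layer = nn.Linear(5, 3)\n",
    "x = torch.randn(4, 5)  # activations: [N, in_features]\n",
    "\n",
    "# Backpropagated gradients dL/dy for the toy loss L = sum(y ** 2)\n",
    "y = layer(x)\n",
    "loss = (y ** 2).sum()\n",
    "(backprops,) = torch.autograd.grad(loss, y)  # [N, out_features]\n",
    "\n",
    "# Vectorized per-sample gradients, using the same einsums as the grad_sampler\n",
    "gs_weight = torch.einsum(\"n...i,n...j->nij\", backprops, x)\n",
    "gs_bias = torch.einsum(\"n...k->nk\", backprops)\n",
    "\n",
    "# Reference: one backward pass per sample\n",
    "for n in range(x.shape[0]):\n",
    "    loss_n = (layer(x[n : n + 1]) ** 2).sum()\n",
    "    ref_w, ref_b = torch.autograd.grad(loss_n, [layer.weight, layer.bias])\n",
    "    assert torch.allclose(gs_weight[n], ref_w, atol=1e-5)\n",
    "    assert torch.allclose(gs_bias[n], ref_b, atol=1e-5)\n"
   ]
  },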
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mm6x75oIMcU8"
   },
   "source": [
    "### Registering a grad_sampler\n",
    "But how do you tell Opacus that this is the grad_sampler? Simply decorate it with `register_grad_sampler`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l3wJw5GYOo-W"
   },
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from opacus.grad_sample import register_grad_sampler\n",
    "\n",
    "\n",
    "@register_grad_sampler(nn.Linear)\n",
    "def compute_linear_grad_sample(\n",
    "    layer: nn.Linear, activations: torch.Tensor, backprops: torch.Tensor\n",
    ") -> Dict[nn.Parameter, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Computes per sample gradients for ``nn.Linear`` layer\n",
    "    Args:\n",
    "        layer: Layer\n",
    "        activations: Activations\n",
    "        backprops: Backpropagations\n",
    "    \"\"\"\n",
    "    gs = torch.einsum(\"n...i,n...j->nij\", backprops, activations)\n",
    "    ret = {layer.weight: gs}\n",
    "    if layer.bias is not None:\n",
    "        ret[layer.bias] = torch.einsum(\"n...k->nk\", backprops)\n",
    "\n",
    "    return ret"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5HTyhM6WO3DY"
   },
   "source": [
    "Once again, that's it! No really, check out the [code](https://github.com/pytorch/opacus/blob/main/opacus/grad_sample/linear.py); it is literally just this.\n",
    "\n",
    "The `register_grad_sampler` decorator defined in [`grad_sample/utils`](https://github.com/pytorch/opacus/blob/main/opacus/grad_sample/utils.py) registers the function as a grad_sampler for `nn.Linear` (which is passed as an argument to the decorator). The `GradSampleModule` maintains a [registry](https://github.com/pytorch/opacus/blob/main/opacus/grad_sample/grad_sample_module.py#L64) of all the grad_samplers and their corresponding modules.\n",
    "\n",
    "If you want to register a custom grad_sampler, all you have to do is decorate your function as shown above. Note that the order of registration matters; if you register more than one grad_sampler for a certain module, the last one wins."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ylUOxcNlk5rb"
   },
   "source": [
    "### Supported modules\n",
    "Opacus offers grad_samplers for most common modules; you can see the full list [here](https://github.com/pytorch/opacus/tree/main/opacus/grad_sample). As you can see, this list is not at all exhaustive; we wholeheartedly welcome your contributions.\n",
    "\n",
    "By design, the `GradSampleModule` does just that: it computes grad samples. While it is built for use with Opacus, it certainly isn't restricted to DP use cases and can be used for any task that needs per-sample gradients.\n",
    "\n",
    "If you have any questions or comments, please don't hesitate to post them on our [forum](https://discuss.pytorch.org/c/opacus/29)."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Per-sample-gradients correctness utility\n",
    "[Here](https://github.com/pytorch/opacus/blob/main/opacus/utils/per_sample_gradients_utils.py) you can find a simple utility function `check_per_sample_gradients_are_correct` that checks if the gradient sampler works correctly with a particular module."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
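    "import torch\n",
    "import torch.nn as nn\n",
    "from opacus.utils.per_sample_gradients_utils import (\n",
    "    check_per_sample_gradients_are_correct,\n",
    ")\n",
    "\n",
    "N, Z, W = 16, 8, 4  # arbitrary example sizes, chosen here for illustration\n",
    "x_shape = [N, Z, W]\n",
    "x = torch.randn(x_shape)\n",
    "model = nn.Linear(W, W + 2)\n",
    "# This will fail only if the Opacus per-sample gradients do not match the micro-batch gradients.\n",
    "assert check_per_sample_gradients_are_correct(x, model)"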
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
